import random
from typing import List, Tuple, Dict, Any
import numpy as np
import torch
# import torch.nn.functional as F # Not explicitly used now
import json
from tqdm import tqdm
from collections import Counter
import argparse
import math
import os
import csv # For CSV output
from sentence_transformers import SentenceTransformer
# Heads import (ensure this path is correct in your environment)
from heads import get_matching_head

# Single seed value handed to every random number generator this script touches.
SEED = 42
# Compute device: the GPU when PyTorch reports CUDA as available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed PyTorch (CPU), NumPy and Python's `random`; the CUDA generators are
# seeded only when a CUDA device was selected above.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)


class MatchingInference:
    """Inference wrapper around a SentenceTransformer encoder and a matching head.

    Attributes set by ``__init__``:
        embedding_model: SentenceTransformer loaded from ``<model_dir>/embedding_model``.
        matching_head: module returned by ``get_matching_head("cos_sim", dim)`` with
            weights loaded from ``<model_dir>/matching_head.pt``.
        tokenid2emb: dict mapping tokenizer vocabulary ids to the embedding of the
            corresponding token string (see ``_build_token_embedding_cache``).
    """

    def __init__(self, model_dir):
        """Load the encoder, the matching head and the token-embedding cache.

        Args:
            model_dir: directory containing ``embedding_model/`` and
                ``matching_head.pt``; the token cache is read from / written to
                ``tokenid_embedding_cache.pt`` in the same directory.
        """
        # NOTE(security): trust_remote_code=True lets the model repository run its
        # own Python code at load time; only use model directories you trust.
        self.embedding_model = SentenceTransformer(f"{model_dir}/embedding_model", trust_remote_code=True, device=DEVICE)
        self.embedding_model.eval()

        embedding_dim = self.embedding_model.get_sentence_embedding_dimension()
        self.matching_head = get_matching_head("cos_sim", embedding_dim)
        # NOTE(security): torch.load unpickles the file; never load checkpoints
        # from an untrusted source.
        self.matching_head.load_state_dict(torch.load(f"{model_dir}/matching_head.pt", map_location=DEVICE))
        self.matching_head = self.matching_head.to(DEVICE)
        self.matching_head.eval()

        self.tokenid2emb = self._build_token_embedding_cache(model_dir)

    def _build_token_embedding_cache(self, model_dir):
        """Return a ``{token_id: embedding}`` dict, loading or creating the on-disk cache.

        If ``<model_dir>/tokenid_embedding_cache.pt`` exists it is loaded; otherwise
        every kept vocabulary token string is encoded with the embedding model, the
        resulting dict is saved to that path, and then returned. All returned
        tensors are moved to ``DEVICE``.
        """
        cache_path = os.path.join(model_dir, "tokenid_embedding_cache.pt")
        if os.path.exists(cache_path):
            print(f"📦 Loading token ID embedding cache from {cache_path} ...")
            # NOTE(security): torch.load unpickles the cache file; keep model_dir trusted.
            tokenid2emb_raw = torch.load(cache_path, map_location=DEVICE)
            return {int(token_id): emb.to(DEVICE) for token_id, emb in tokenid2emb_raw.items()}

        print("⚙️ Building token ID embedding index from tokenizer vocab...")
        tokenizer = self.embedding_model.tokenizer
        vocab = tokenizer.get_vocab()
        # Skip tokens that start with "[" (presumably special tokens such as
        # [CLS]/[SEP] -- confirm for the tokenizer in use) and whitespace-only tokens.
        filtered_items = [(tok, idx) for tok, idx in vocab.items() if not tok.startswith("[") and tok.strip()]

        tokens = [tok for tok, _ in filtered_items]
        ids = [idx for _, idx in filtered_items]

        # Each token string is encoded on its own, as if it were a sentence.
        token_embs = self.embedding_model.encode(tokens, convert_to_tensor=True, show_progress_bar=True, device=DEVICE)
        tokenid2emb = {int(i): emb for i, emb in zip(ids, token_embs)}

        torch.save(tokenid2emb, cache_path)
        print(f"✅ Cached embeddings for {len(tokenid2emb)} token ids to {cache_path}")
        return {int(k): v.to(DEVICE) for k, v in tokenid2emb.items()}

    def encode(self, text: str) -> torch.Tensor:
        """Encode ``text`` with the embedding model and return the tensor.

        NOTE(review): unlike ``predict_batch`` this does not pass
        ``normalize_embeddings=True``; confirm the matching head is insensitive
        to embedding norm.
        """
        return self.embedding_model.encode(text, convert_to_tensor=True, device=DEVICE)

    def score(self, emb_a: torch.Tensor, emb_b: torch.Tensor) -> float:
        """Return ``sigmoid(logits)`` of the matching head for one embedding pair.

        Both embeddings are moved to ``DEVICE`` and given a leading batch
        dimension of 1 before being passed to the head.
        """
        emb_a = emb_a.to(DEVICE)
        emb_b = emb_b.to(DEVICE)
        features = {
            "embedding_a": emb_a.unsqueeze(0),
            "embedding_b": emb_b.unsqueeze(0)
        }
        with torch.no_grad():
            logits = self.matching_head(features)["logits"]
            return torch.sigmoid(logits).item()

    @torch.no_grad()
    def predict_batch(self, answers, reasons, batch_size=32):
        """Score aligned (answer, reason) text pairs and return one probability per pair.

        Args:
            answers: sequence of texts fed to the head as ``embedding_a``.
            reasons: sequence of texts fed to the head as ``embedding_b``;
                must have the same length as ``answers``.
            batch_size: number of pairs encoded and scored per step.

        Returns:
            A list of floats (sigmoid of the head's logits), in input order.

        Raises:
            ValueError: if ``answers`` and ``reasons`` differ in length.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(answers) != len(reasons):
            raise ValueError(
                f"answers and reasons must have the same length, got {len(answers)} and {len(reasons)}"
            )
        all_probs = []
        for idx in range(0, len(answers), batch_size):
            batch_answers = answers[idx:idx+batch_size]
            batch_reasons = reasons[idx:idx+batch_size]

            emb_a = self.embedding_model.encode(batch_answers, convert_to_tensor=True, normalize_embeddings=True, device=DEVICE)
            emb_b = self.embedding_model.encode(batch_reasons, convert_to_tensor=True, normalize_embeddings=True, device=DEVICE)

            features = {"embedding_a": emb_a, "embedding_b": emb_b}
            outputs = self.matching_head(features)
            # reshape(-1) always yields a 1-D tensor. squeeze(-1) would collapse a
            # (1,)-shaped logits tensor (batch of one) to 0-dim, after which
            # tolist() returns a bare float and list.extend() raises TypeError.
            logits = outputs["logits"].reshape(-1)
            probs = torch.sigmoid(logits)
            all_probs.extend(probs.tolist())
        return all_probs

    def score_pair(self, reason: str, answer: str) -> float:
        """Return the matching probability for a single (reason, answer) text pair."""
        probs = self.predict_batch([answer], [reason], batch_size=1)
        return probs[0]


# -------- Model adapter interface (for score_a calculation) --------
def make_model_a_scorer(infer_a: MatchingInference):
    """Build the token-vs-sentence scorer backed by ``infer_a``.

    Returns:
        A ``(model_a_score, tokenizer)`` tuple where ``tokenizer`` is
        ``infer_a.embedding_model.tokenizer`` and ``model_a_score`` is the
        closure described below.
    """
    tokenizer = infer_a.embedding_model.tokenizer

    def model_a_score(token_ids: List[int], sentence_text: str) -> float:
        """Score the mean cached embedding of ``token_ids`` against ``sentence_text``.

        Token ids missing from ``infer_a.tokenid2emb`` are ignored; when none
        are found the score is 0.0. Otherwise the result is the sigmoid of the
        matching head's logits.
        """
        cache = infer_a.tokenid2emb
        known_embs = [cache[tid] for tid in token_ids if tid in cache]
        if len(known_embs) == 0:
            return 0.0

        # Average of the per-token embeddings plays the role of "embedding_a".
        mean_token_emb = torch.stack(known_embs).mean(dim=0)
        sentence_emb = infer_a.encode(sentence_text)

        head_inputs = {
            "embedding_a": mean_token_emb.to(DEVICE).unsqueeze(0),
            "embedding_b": sentence_emb.to(DEVICE).unsqueeze(0),
        }
        with torch.no_grad():
            head_logits = infer_a.matching_head(head_inputs)["logits"]
            return torch.sigmoid(head_logits).item()

    return model_a_score, tokenizer

# -------- Rule-based Sufficiency Check --------
def _check_sufficiency_rule_based(
    current_scores_a: List[float],
    current_scores_b: List[float],
    verification_threshold: float
) -> Tuple[bool, float, float]:
    """Decide sufficiency from the mean of each score list.

    Returns ``(sufficient, mean_a, mean_b)``. ``sufficient`` is True only when
    both means are at least ``verification_threshold``. If either list is
    empty, or the lists differ in length (a warning is printed in that case),
    the result is ``(False, 0.0, 0.0)``.
    """
    if not (current_scores_a and current_scores_b):
        return False, 0.0, 0.0

    count_a = len(current_scores_a)
    count_b = len(current_scores_b)
    if count_a != count_b:
        print(f"Warning: Mismatch in score list lengths for rule-based verifier. A: {len(current_scores_a)}, B: {len(current_scores_b)}")
        return False, 0.0, 0.0

    mean_a = sum(current_scores_a) / count_a
    mean_b = sum(current_scores_b) / count_b

    # Both averages must clear the same threshold.
    passes = (
        mean_a >= verification_threshold
        and mean_b >= verification_threshold
    )
    return passes, mean_a, mean_b

# -------- Pre-computation and Validation Logic --------
def preprocess_samples_for_efficient_validation(
    data: List[Dict],
    infer_a: MatchingInference,
    infer_b: MatchingInference,
    token_ratio: float
) -> List[Dict]:
    """Pre-compute per-block scores for every sample so parameter sweeps are cheap.

    For each item, the non-blank sentences of ``item["R"]`` are shuffled and each
    one ("block") is scored twice:
      * ``a``: model A's score of a random subset of the block's token ids against
        the block text (subset size is ``max(1, int(n_tokens * token_ratio))``,
        capped at ``n_tokens``);
      * ``b``: ``infer_b.score_pair(block_text, item["A"])``.

    Args:
        data: list of dicts with keys ``"R"`` (list of sentences) and ``"A"``
            (answer text); ``"label"`` is optional.
        infer_a: inference engine used for the ``a`` score and its tokenizer.
        infer_b: inference engine used for the ``b`` score.
        token_ratio: fraction of a block's tokens sampled for the ``a`` score.

    Returns:
        One dict per input item with keys ``"label"`` (``item.get("label")``),
        ``"block_scores"`` (list of ``{'a': float, 'b': float}``) and ``"alpha"``
        (number of non-blank sentences). ``alpha`` can exceed
        ``len(block_scores)`` because blocks that tokenize to nothing or fail
        to score are skipped.
    """
    print("🚀 Pre-processing samples to compute block scores (rule-based)...")
    preprocessed_data = []
    model_a_scorer, tokenizer_a = make_model_a_scorer(infer_a)

    # Scoring failures are counted and reported rather than silently discarded;
    # only the first few are printed in detail to avoid flooding the console.
    failed_blocks = 0
    max_reported_failures = 5

    for item in tqdm(data, desc="Pre-calculating block scores"):
        R_sentences, A_text = item["R"], item["A"]

        filtered_sentences = [s for s in R_sentences if s and s.strip()]
        random.shuffle(filtered_sentences)
        alpha_count = len(filtered_sentences)

        # With no usable sentences the loop below is a no-op, which yields the
        # same record as an explicit empty-sample branch: block_scores=[], alpha=0.
        current_sample_block_scores = []
        for block_text in filtered_sentences:
            encoding = tokenizer_a(block_text, add_special_tokens=False, return_tensors='pt')
            block_token_ids = encoding["input_ids"][0].tolist()

            if not block_token_ids:
                continue

            sample_size = max(1, int(len(block_token_ids) * token_ratio))
            sample_size = min(sample_size, len(block_token_ids))
            selected_ids = random.sample(block_token_ids, sample_size)

            try:
                score_a = model_a_scorer(selected_ids, block_text)
                score_b = infer_b.score_pair(block_text, A_text)
            except Exception as e:
                # Best-effort: keep going, but make the failure visible.
                failed_blocks += 1
                if failed_blocks <= max_reported_failures:
                    # tqdm.write keeps the progress bar intact while printing.
                    tqdm.write(f"Warning: Error scoring block '{block_text[:30]}...': {e}. Skipping block.")
                continue
            current_sample_block_scores.append({'a': score_a, 'b': score_b})

        preprocessed_data.append({
            "label": item.get("label"),
            "block_scores": current_sample_block_scores,
            "alpha": alpha_count
        })

    if failed_blocks:
        print(f"⚠️ {failed_blocks} block(s) could not be scored and were skipped.")
    return preprocessed_data


def run_validation_on_preprocessed_rule_based(
    preprocessed_sample_data: Dict,
    probing_ratio: float,
    verification_threshold: float
) -> Tuple[bool, int]:
    """
    Apply the rule-based verifier to one sample's pre-computed block scores.

    Blocks are consumed in order. Once at least
    ``min(alpha, ceil(max(1, probing_ratio * alpha)))`` blocks have been seen,
    the running averages are checked after every block and the scan stops at
    the first passing check. If no check passes during the scan, one final
    check over everything accumulated decides the outcome.

    Returns a tuple: (predicted_sufficiency, number_of_blocks_validated).
    """
    scored_blocks = preprocessed_sample_data["block_scores"]
    total_blocks = preprocessed_sample_data["alpha"]

    # Nothing was scored for this sample: insufficient, zero rounds.
    if not scored_blocks:
        return False, 0

    # Minimum number of blocks to look at before early stopping is allowed.
    if total_blocks == 0:
        probe_target = 0
    else:
        probe_target = min(total_blocks, math.ceil(max(1.0, probing_ratio * total_blocks)))

    seen_a = []
    seen_b = []
    blocks_used = 0

    for blocks_used, pair in enumerate(scored_blocks, start=1):
        seen_a.append(pair['a'])
        seen_b.append(pair['b'])

        if probe_target > 0 and blocks_used >= probe_target:
            verdict, _, _ = _check_sufficiency_rule_based(
                seen_a, seen_b, verification_threshold
            )
            if verdict:
                # Early stop: the averages already clear the threshold.
                return True, blocks_used

    # No early stop happened; fall back to a check over all accumulated scores.
    if not seen_a:
        return False, blocks_used
    verdict, _, _ = _check_sufficiency_rule_based(
        seen_a, seen_b, verification_threshold
    )
    return verdict, blocks_used

# -------- Main entry point --------
if __name__ == "__main__":
    # --- Stage 0: command-line interface ---
    cli = argparse.ArgumentParser(description="Rule-based validation with efficient parameter sweeping.")
    cli.add_argument("--model_a_dir", type=str, required=True, help="Directory for model A components.")
    cli.add_argument("--model_b_dir", type=str, required=True, help="Directory for model B components (used for R-A scoring).")
    cli.add_argument("--data_path", type=str, required=True, help="Path to the input JSON data file.")
    cli.add_argument("--accuracy_csv_output_path", type=str, required=True, help="Path to save the accuracy CSV results.")
    cli.add_argument("--rounds_csv_output_path", type=str, required=True, help="Path to save the average validation rounds CSV results.")
    cli.add_argument("--token_ratio", type=float, default=0.1, help="Theta: Ratio of tokens to sample within a block for model A's score_a.")

    cli.add_argument("--verification_thresholds", type=float, nargs='+', required=True, help="List of verification thresholds (for avg_score_a AND avg_score_b).")
    cli.add_argument("--probing_ratios", type=float, nargs='+', required=True, help="List of probing ratios (gamma) to test for early stopping.")
    cli.add_argument("--max_samples", type=int, default=None, help="Process a maximum number of samples for quick testing (e.g., 100). Default: all.")

    args = cli.parse_args()

    print(f"Using device: {DEVICE}")

    # --- Stage 1: load both inference engines and the dataset ---
    print("🚀 Initializing Inference Engines (MatchingInference)...")
    engine_a = MatchingInference(args.model_a_dir)
    engine_b = MatchingInference(args.model_b_dir)
    print("✅ Inference Engines Ready.")

    with open(args.data_path, "r", encoding="utf-8") as data_file:
        loaded_items = json.load(data_file)

    # Optionally truncate to the first --max_samples items.
    take_subset = args.max_samples is not None and args.max_samples > 0
    if take_subset:
        print(f"🔪 Using a subset of {args.max_samples} samples for processing.")
    items_to_process = loaded_items[:args.max_samples] if take_subset else loaded_items

    # Block scores are computed once and reused by every parameter combination.
    preprocessed_samples = preprocess_samples_for_efficient_validation(
        items_to_process, engine_a, engine_b, args.token_ratio
    )

    # --- Stage 2: sweep every (threshold, probing ratio) combination ---
    accuracy_rows = []
    rounds_rows = []

    print(f"\n⚙️ Starting rule-based multi-parameter validation for {len(args.verification_thresholds)} thresholds and {len(args.probing_ratios)} probing ratios...")

    # Denominator for the average number of validation rounds.
    sample_count = len(preprocessed_samples)

    for v_thresh in tqdm(args.verification_thresholds, desc="Thresholds"):
        for p_ratio in tqdm(args.probing_ratios, desc="Probing Ratios", leave=False):
            n_correct = 0
            n_labeled = 0
            rounds_sum = 0

            for sample in preprocessed_samples:
                gold = sample["label"]
                predicted, rounds_used = run_validation_on_preprocessed_rule_based(
                    sample, p_ratio, v_thresh
                )
                rounds_sum += rounds_used

                # Unlabeled samples count toward rounds but not toward accuracy.
                if gold is None:
                    continue
                n_labeled += 1
                if predicted == gold:
                    n_correct += 1

            if n_labeled > 0:
                accuracy = n_correct / n_labeled
            else:
                accuracy = 0.0

            accuracy_rows.append({
                "verification_threshold": v_thresh,
                "probing_ratio": p_ratio,
                "accuracy": f"{accuracy:.4f}",
                "correct_predictions": n_correct,
                "total_labeled_samples": n_labeled
            })

            if sample_count > 0:
                mean_rounds = rounds_sum / sample_count
            else:
                mean_rounds = 0.0
            rounds_rows.append({
                "verification_threshold": v_thresh,
                "probing_ratio": p_ratio,
                "avg_validation_rounds": f"{mean_rounds:.2f}"
            })

    # --- Stage 3a: write the accuracy table ---
    if accuracy_rows:
        with open(args.accuracy_csv_output_path, "w", newline='', encoding="utf-8") as accuracy_file:
            accuracy_writer = csv.DictWriter(accuracy_file, fieldnames=list(accuracy_rows[0]))
            accuracy_writer.writeheader()
            accuracy_writer.writerows(accuracy_rows)
        print(f"\n✅ Accuracy results for different combinations saved to {args.accuracy_csv_output_path}")
    else:
        print("\n⚠️ No accuracy results to save to CSV. Check data or parameters (ensure labeled data exists).")

    # --- Stage 3b: write the average-rounds table to its own CSV ---
    if rounds_rows:
        with open(args.rounds_csv_output_path, "w", newline='', encoding="utf-8") as rounds_file:
            rounds_writer = csv.DictWriter(rounds_file, fieldnames=list(rounds_rows[0]))
            rounds_writer.writeheader()
            rounds_writer.writerows(rounds_rows)
        print(f"\n✅ Average validation rounds for different combinations saved to {args.rounds_csv_output_path}")
    else:
        print("\n⚠️ No average rounds results to save to CSV. Check data or parameters.")

    # --- Stage 4: console summaries ---
    print("\n----- Rule-Based Multi-Parameter Validation Summary (Accuracy) -----")
    for res in accuracy_rows:
        print(f"Threshold: {res['verification_threshold']:.3f}, Probing Ratio: {res['probing_ratio']:.2f}, Accuracy: {res['accuracy']} ({res['correct_predictions']}/{res['total_labeled_samples']})")

    print("\n----- Rule-Based Multi-Parameter Validation Summary (Average Rounds) -----")
    for res in rounds_rows:
        print(f"Threshold: {res['verification_threshold']:.3f}, Probing Ratio: {res['probing_ratio']:.2f}, Avg. Rounds: {res['avg_validation_rounds']}")